package cn.bugstack.openai.executor.model.baidu;

import cn.bugstack.openai.exception.OpenAiSdkException;
import cn.bugstack.openai.executor.Executor;
import cn.bugstack.openai.executor.model.baidu.config.BaiduConfig;
import cn.bugstack.openai.executor.model.baidu.valobj.Message;
import cn.bugstack.openai.executor.model.baidu.valobj.Usage;
import cn.bugstack.openai.executor.model.baidu.valobj.*;
import cn.bugstack.openai.executor.parameter.*;
import cn.bugstack.openai.executor.result.ResultHandler;
import cn.bugstack.openai.session.Configuration;
import com.chatplus.application.common.util.PlusJsonUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;

import javax.annotation.Nullable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Executor for the Baidu Wenxin Yiyan large language model.
 *
 * <p>Translates the SDK's generic {@link CompletionRequest} / {@link ImageRequest} into
 * Baidu-specific payloads, sends them to {@code BaiduConfig.API_HOST}, and converts the
 * streamed Baidu replies back into the SDK's {@link CompletionResponse} JSON before handing
 * them to the caller's {@link EventSourceListener}.
 */
@Slf4j
public class BaiduModelExecutor implements Executor, ParameterHandler<BaiduCompletionRequest>, ResultHandler {

    /** Factory used to open server-sent-event streams for chat completions. */
    private final EventSource.Factory factory;

    /** HTTP client used for the blocking image-generation call. */
    private final OkHttpClient okHttpClient;

    /**
     * Creates an executor that takes its HTTP client and SSE factory from the SDK configuration.
     *
     * @param configuration SDK configuration supplying the {@link OkHttpClient} and the
     *                      {@link EventSource.Factory}
     */
    public BaiduModelExecutor(Configuration configuration) {
        this.okHttpClient = configuration.getOkHttpClient();
        this.factory = configuration.createRequestFactory();
    }

    /**
     * Starts a streaming chat completion using the default host.
     *
     * @param completionRequest   generic completion request; must have {@code stream == true}
     * @param eventSourceListener caller's listener that receives the converted events
     * @return the opened event source
     * @throws Exception if the request is rejected or cannot be built
     */
    @Override
    public EventSource completions(CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        return completions(null, null, completionRequest, eventSourceListener);
    }

    /**
     * Starts a streaming chat completion.
     *
     * <p>{@code apiHostByUser} and {@code apiKeyByUser} are not used by this implementation;
     * the request always goes to {@code BaiduConfig.API_HOST} plus the model-specific path.
     *
     * @param apiHostByUser       ignored
     * @param apiKeyByUser        ignored
     * @param completionRequest   generic completion request; must have {@code stream == true}
     * @param eventSourceListener caller's listener that receives the converted events
     * @return the opened event source
     * @throws OpenAiSdkException if {@code completionRequest.isStream()} is {@code false}
     */
    @Override
    public EventSource completions(String apiHostByUser, String apiKeyByUser, CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        // 1. Validate core parameters; the caller's arguments are never modified, only rejected.
        if (!completionRequest.isStream()) {
            throw new OpenAiSdkException("illegal parameter stream is false!");
        }

        // 2. Convert the generic request into the Baidu-specific request.
        BaiduCompletionRequest baiduCompletionRequest = getParameterObject(completionRequest);

        // 3. Build the HTTP request for the endpoint that matches the requested model.
        // NOTE(review): the body relies on BaiduCompletionRequest#toString() producing JSON — confirm.
        MediaType mediaType = MediaType.parse(Configuration.APPLICATION_JSON);
        String url = BaiduConfig.CompletionsUrl.fromCode(completionRequest.getModel()).getUrl();
        Request request = new Request.Builder()
                .tag(completionRequest.getTag())
                .addHeader("Content-Type", Configuration.APPLICATION_JSON)
                .url(BaiduConfig.API_HOST.concat(url))
                .post(RequestBody.create(baiduCompletionRequest.toString(), mediaType))
                .build();

        // 4. Open the stream, wrapping the caller's listener so events are converted on the way out.
        return factory.newEventSource(request, eventSourceListener(eventSourceListener));
    }

    /**
     * Generates images using the default host.
     *
     * @param imageRequest generic image request
     * @return the parsed image response
     * @throws Exception if the call fails
     */
    @Override
    public ImageResponse genImages(ImageRequest imageRequest) throws Exception {
        return genImages(null, null, imageRequest);
    }

    /**
     * Generates images with a blocking HTTP call.
     *
     * <p>{@code apiHostByUser} and {@code apiKeyByUser} are not used by this implementation.
     *
     * @param apiHostByUser ignored
     * @param apiKeyByUser  ignored
     * @param imageRequest  generic image request; its count, size, prompt and model are read
     * @return the response body parsed as an {@link ImageResponse}
     * @throws IOException if the call fails, the status is not successful, or the body is missing
     */
    @Override
    public ImageResponse genImages(String apiHostByUser, String apiKeyByUser, ImageRequest imageRequest) throws IOException {
        // 1. Convert the generic request into the Baidu image request.
        BaiduImageRequest baiduImageRequest = BaiduImageRequest.builder()
                .n(imageRequest.getN())
                .size(imageRequest.getSize())
                .prompt(imageRequest.getPrompt())
                .build();

        // 2. Build the HTTP request for the endpoint that matches the requested model.
        String url = BaiduConfig.CompletionsUrl.fromCode(imageRequest.getModel()).getUrl();
        Request request = new Request.Builder()
                .addHeader("Content-Type", Configuration.APPLICATION_JSON)
                .url(BaiduConfig.API_HOST + url)
                .post(RequestBody.create(PlusJsonUtils.toJsonString(baiduImageRequest), MediaType.parse(Configuration.APPLICATION_JSON)))
                .build();

        // 3. Execute synchronously; try-with-resources guarantees the response is closed.
        Call call = okHttpClient.newCall(request);
        try (Response response = call.execute()) {
            if (response.isSuccessful() && response.body() != null) {
                return PlusJsonUtils.parseObject(response.body().string(), ImageResponse.class);
            } else {
                throw new IOException("Failed to get image response");
            }
        }
    }

    /**
     * Picture understanding is not implemented for this executor.
     *
     * @return always {@code null}
     */
    @Override
    public EventSource pictureUnderstanding(PictureRequest pictureRequest, EventSourceListener eventSourceListener) throws Exception {
        return null;
    }

    /**
     * Picture understanding is not implemented for this executor.
     *
     * @return always {@code null}
     */
    @Override
    public EventSource pictureUnderstanding(String apiHostByUser, String apiKeyByUser, PictureRequest pictureRequest, EventSourceListener eventSourceListener) throws Exception {
        return null;
    }

    /**
     * Converts a generic completion request into a Baidu completion request.
     *
     * <p>Messages whose role equals the system role code are not added to the message list;
     * their content is stored in the request's {@code system} field instead (if several are
     * present, the last one wins). All other messages are copied in order, so the
     * conversation history is carried over.
     *
     * @param completionRequest generic completion request
     * @return the Baidu request carrying messages, system prompt, stream flag, top-p and temperature
     */
    @Override
    public BaiduCompletionRequest getParameterObject(CompletionRequest completionRequest) {
        BaiduCompletionRequest baiduCompletionRequest = new BaiduCompletionRequest();

        // Convert the messages, keeping the history.
        List<Message> wenXinMessages = new ArrayList<>();
        List<cn.bugstack.openai.executor.parameter.Message> messages = completionRequest.getMessages();
        for (cn.bugstack.openai.executor.parameter.Message message : messages) {
            if (CompletionRequest.Role.SYSTEM.getCode().equals(message.getRole())) {
                // System prompts travel in a dedicated field rather than in the message list.
                baiduCompletionRequest.setSystem(message.getContent());
                continue;
            }
            Message messageVo = new Message();
            messageVo.setRole(message.getRole());
            messageVo.setContent(message.getContent());
            wenXinMessages.add(messageVo);
        }

        // Populate the remaining request parameters.
        baiduCompletionRequest.setStream(completionRequest.isStream());
        baiduCompletionRequest.setTopP(completionRequest.getTopP());
        baiduCompletionRequest.setTemperature(completionRequest.getTemperature());
        baiduCompletionRequest.setMessages(wenXinMessages);

        return baiduCompletionRequest;
    }

    /**
     * Reports whether function calling is supported for the given model.
     *
     * @param model model code
     * @return always {@code false}
     */
    @Override
    public boolean supportFunction(String model) {
        return false;
    }


    /**
     * Wraps the caller's listener so that each Baidu event is re-encoded as a
     * {@link CompletionResponse} JSON string before being forwarded.
     *
     * <p>Intermediate chunks (where {@code isEnd} is {@code Boolean.FALSE}) are forwarded with
     * the original event id and type. Any other chunk is treated as the final one: the choice
     * gets finish reason {@code "stop"}, token usage is attached when the reply carries it, a
     * creation timestamp is set, and the event is forwarded with a {@code null} id and type.
     * Open, closed and failure callbacks are delegated unchanged.
     *
     * @param eventSourceListener caller's listener to delegate to
     * @return the converting listener
     */
    @Override
    public EventSourceListener eventSourceListener(EventSourceListener eventSourceListener) {
        return new EventSourceListener() {
            @Override
            public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
                BaiduCompletionResponse response = PlusJsonUtils.parseObject(data, BaiduCompletionResponse.class);

                // Every chunk is exposed as exactly one choice whose delta carries the chunk text.
                // NOTE(review): the delta is tagged with the SYSTEM role although it is model output — confirm intended.
                ChatChoice chatChoice = new ChatChoice();
                chatChoice.setDelta(cn.bugstack.openai.executor.parameter.Message.builder()
                        .name("")
                        .role(CompletionRequest.Role.SYSTEM)
                        .content(response.getResult())
                        .build());
                List<ChatChoice> choices = new ArrayList<>();
                choices.add(chatChoice);

                CompletionResponse completionResponse = new CompletionResponse();
                completionResponse.setChoices(choices);

                // Conversation not finished yet: forward the chunk as-is.
                if (Boolean.FALSE.equals(response.getIsEnd())) {
                    eventSourceListener.onEvent(eventSource, id, type, PlusJsonUtils.toJsonString(completionResponse));
                    return;
                }

                // Final chunk: mark the single choice as finished. The choice is already in the
                // list, so it must not be added a second time.
                chatChoice.setFinishReason("stop");

                // Attach token usage when the reply carries it.
                Usage usage = response.getUsage();
                if (usage != null) {
                    cn.bugstack.openai.executor.parameter.Usage openaiUsage = new cn.bugstack.openai.executor.parameter.Usage();
                    openaiUsage.setPromptTokens(usage.getPromptTokens());
                    openaiUsage.setCompletionTokens(usage.getCompletionTokens());
                    openaiUsage.setTotalTokens(usage.getTotalTokens());
                    completionResponse.setUsage(openaiUsage);
                }
                completionResponse.setCreated(System.currentTimeMillis());

                // Deliver the final result.
                eventSourceListener.onEvent(eventSource, null, null, Objects.requireNonNull(PlusJsonUtils.toJsonString(completionResponse)));
            }

            @Override
            public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
                eventSourceListener.onOpen(eventSource, response);
            }

            @Override
            public void onClosed(@NotNull EventSource eventSource) {
                eventSourceListener.onClosed(eventSource);
            }

            @Override
            public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
                eventSourceListener.onFailure(eventSource, t, response);
            }
        };
    }
}
